In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
import fastai as ai
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
from PIL import Image, ImageDraw
import shutil
import requests, zipfile, io
import os
from tqdm import tqdm
import dropbox

Prepare data

Paths

In [3]:
# Root of the capstone dataset plus the three image folders:
#   color      — original COCO images (ground truth)
#   grayscale  — black-and-white copies (generator input)
#   colorized  — generator predictions (critic "fake" class)
path = Path("/storage/capstone/")
color = 'color'
grayscale = 'grayscale'
colorized = 'colorized'
path_color = path/color
path_gray = path/grayscale
path_gen = path/colorized

# exist_ok avoids the check-then-create race and is a no-op when already present
path_color.mkdir(parents=True, exist_ok=True)
path_gray.mkdir(parents=True, exist_ok=True)

path.ls()
Out[3]:
[PosixPath('/storage/capstone/grayscale'),
 PosixPath('/storage/capstone/WGU_Capstone.ipynb'),
 PosixPath('/storage/capstone/color'),
 PosixPath('/storage/capstone/colorized'),
 PosixPath('/storage/capstone/WGU_Capstone-MSE.ipynb'),
 PosixPath('/storage/capstone/models'),
 PosixPath('/storage/capstone/.ipynb_checkpoints')]

Download data

In [4]:
def download_zip(url, path, fn, chunk_size=128):
    """Stream-download `url` to `path/fn`.

    Parameters
    ----------
    url : str            source URL
    path : Path          destination directory (must exist)
    fn : str             destination filename
    chunk_size : int     bytes written per chunk; pass something large
                         (e.g. 1024*1024) for multi-GB archives
    """
    r = requests.get(url, stream=True)
    if not r.ok:
        # report the HTTP status instead of a bare "request failed"
        print("request failed: %d %s" % (r.status_code, r.reason))
        return
    with open(path/fn, 'wb') as f:
        for chunk in r.iter_content(chunk_size=chunk_size):
            f.write(chunk)
In [5]:
# download 2017 coco data from http://cocodataset.org/#download
#url = "http://images.cocodataset.org/zips/train2017.zip"
#zip_fn = 'coco.zip'
#download_zip(url, path_color, zip_fn, 1024*1024*100)
In [6]:
# unzip file
#with zipfile.ZipFile(path_color/zip_fn) as z:
#    z.extractall(path_color)
In [7]:
# move images from extract folder to parent folder
#for fp in (path_color/'train2017').ls():
#    shutil.move(fp, path_color/fp.name)
In [8]:
#os.remove(path_color/zip_fn)

Make grayscale copies of images

In [9]:
def bnw(fn, i, dest):
    """Save a grayscale ('L' mode) copy of image `fn` into `dest/fn.name`.

    `i` is the index supplied by fastai's `parallel` helper and is unused.
    Unreadable/corrupt images are silently skipped so one bad file does not
    abort the whole parallel run.
    """
    try:
        im = Image.open(fn)
        im.verify()  # cheap integrity check; invalidates the file handle
    except Exception:
        # narrow from bare `except:` so Ctrl-C / SystemExit still propagate
        return
    # verify() leaves the image unusable, so reopen before converting
    im = Image.open(fn)
    im = im.convert('L')
    im.save(dest/fn.name, quality=100)
In [10]:
# save data to storage (color and grayscale copies)
#il = ImageList.from_folder(path_color)
#ai.core.parallel(partial(bnw, dest=path_gray), il.items)

Set up generator data

In [11]:
def get_generator_data(bs, size, p=1.):
    """Build the generator databunch: grayscale inputs paired with their
    color originals by filename.

    bs   -- batch size
    size -- image size for the transforms
    p    -- fraction of the dataset to use (1. = everything)
    """
    def label_func(gray_fn):
        # the color target lives in path_color under the same filename
        return path_color/gray_fn.name

    items = ImageImageList.from_folder(path_gray).use_partial_data(p)
    src = items.split_by_rand_pct(0.1).label_from_func(label_func)
    data = (src.transform(tfms=get_transforms(), size=size, tfm_y=True)
            .databunch(bs=bs)
            .normalize(imagenet_stats, do_y=True))
    data.c = 3  # three output channels for the RGB target
    return data

Set up critic data

In [12]:
def save_preds(loader, learner, dest):
    """Run `learner` over every batch of `loader` and save each reconstructed
    prediction into `dest`, reusing the source image's filename.

    NOTE(review): assumes `loader` yields batches in dataset order (e.g.
    `data.fix_dl`) so predictions line up with `loader.dataset.items`.
    """
    dest.mkdir(parents=True, exist_ok=True)
    names = iter(loader.dataset.items)
    for batch in loader:
        for pred in learner.pred_batch(batch=batch, reconstruct=True):
            pred.save(dest/next(names).name)
In [13]:
def get_critic_data(classes, bs, size):
    """Build the critic databunch: a two-class ImageList labelled by folder
    name (`classes` selects which subfolders of `path` are included).
    """
    items = ImageList.from_folder(path, include=classes)
    labelled = items.split_by_rand_pct(0.1).label_from_folder(classes=classes)
    data = (labelled.transform(tfms=get_transforms(), size=size)
            .databunch(bs=bs)
            .normalize(imagenet_stats))
    data.c = 3  # kept at 3 to match the generator's data, per the fastai GAN recipe
    return data

Set data parameters

In [14]:
# Base training parameters; later GAN rounds scale bs down (bs*2//3, bs//3,
# bs//6) as `size` is progressively raised from 128 to 192/256/320.
bs = 48
size = 128

Set up models

In [15]:
arch = models.resnet34   # encoder backbone for the U-Net generator
loss_gen = MSELossFlat() # pixel-wise MSE used to pre-train the generator
wd = 1e-3                # weight decay shared by every learner below


def get_generator(data_gen):
    """U-Net learner on `data_gen` with a resnet34 encoder, self-attention,
    blur upsampling and weight normalisation."""
    return unet_learner(
        data_gen, arch,
        loss_func=loss_gen,
        wd=wd,
        blur=True,
        norm_type=NormType.Weight,
        self_attention=True,
    )
In [16]:
# AdaptiveLoss adapts BCE-with-logits to the critic's output shape
loss_crit = AdaptiveLoss(nn.BCEWithLogitsLoss())


def get_critic(data_crit, metrics):
    """Plain Learner wrapping fastai's spectral-norm `gan_critic` network."""
    return Learner(
        data_crit, gan_critic(),
        metrics=metrics,
        loss_func=loss_crit,
        wd=wd,
    )

Pre-train generative model

In [ ]:
data_gen = get_generator_data(bs, size)
data_gen.show_batch(ds_type=DatasetType.Train, rows=1)
In [18]:
generator = get_generator(data_gen)
In [19]:
generator.fit_one_cycle(3, 1e-3)
epoch train_loss valid_loss time
0 0.118355 0.119353 16:43
1 0.107789 0.110806 17:07
2 0.104482 0.107539 17:07
In [20]:
generator.unfreeze()
In [21]:
generator.lr_find()
generator.recorder.plot()
0.00% [0/1 00:00<00:00]
epoch train_loss valid_loss time

3.20% [71/2217 00:33<16:39 0.1969]
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [22]:
generator.fit_one_cycle(6, slice(1e-6,1e-4))
epoch train_loss valid_loss time
0 0.107941 0.107307 17:18
1 0.103494 0.106698 17:42
2 0.106161 0.105904 17:31
3 0.102731 0.105520 18:31
4 0.099933 0.105044 18:06
5 0.102696 0.105270 18:08
In [23]:
generator.show_results(rows=2)
In [24]:
generator.save('gen-pre')

Save predictions

In [25]:
save_preds(data_gen.fix_dl, generator, path_gen)
generator = None
gc.collect()
Out[25]:
17935

Pre-train critic

In [17]:
data_crit = get_critic_data([colorized, color], bs=bs, size=size)
data_crit.show_batch(ds_type=DatasetType.Train, rows=2)
In [18]:
critic = get_critic(data_crit, accuracy_thresh_expand)
In [19]:
critic.fit_one_cycle(6, 1e-3)
epoch train_loss valid_loss accuracy_thresh_expand time
0 0.114383 0.154046 0.938306 47:39
1 0.064159 0.056589 0.977996 47:45
2 0.038673 0.042749 0.986632 47:33
3 0.013506 0.016955 0.994756 47:29
4 0.007489 0.012178 0.996099 47:29
5 0.015255 0.010709 0.996598 47:23
In [20]:
critic.save('critic-pre')
critic = None
gc.collect()
Out[20]:
8015

GAN

In [17]:
def refresh_gan(version, crit_thresh=0.65, loss_weights=(1.,50.), bs=48, size=128, p=1.):
    """Rebuild the GAN learner from scratch (used between rounds to free GPU memory).

    version      -- 'pre' to assemble from the pre-trained generator/critic
                    checkpoints, otherwise the name of a saved GANLearner
                    checkpoint to resume from
    crit_thresh  -- AdaptiveGANSwitcher threshold for switching to the generator
    loss_weights -- (critic, generator) loss weighting for the generator loss
    bs, size, p  -- databunch parameters (batch size, image size, data fraction)

    NOTE(review): the critic data here uses the `grayscale` folder as the
    second class (the pre-training cells used `colorized`) — this mirrors the
    fastai superres-GAN recipe, where the critic's "fake" images come from the
    generator at train time; confirm this is intentional.
    """
    data_gen = get_generator_data(bs, size, p)
    data_crit = get_critic_data([grayscale, color], bs=bs, size=size)
    generator = get_generator(data_gen)
    critic = get_critic(data_crit, metrics=None)
    pretrained = version == 'pre'
    if pretrained:
        # load the individually pre-trained nets before assembling the GAN
        generator.load('gen-pre')
        critic.load('critic-pre')
    switcher = partial(AdaptiveGANSwitcher, critic_thresh=crit_thresh)
    # single construction site instead of the duplicated call in each branch
    learn = GANLearner.from_learners(generator, critic, weights_gen=loss_weights,
                                     show_img=True, switcher=switcher,
                                     opt_func=optim.Adam, wd=wd)
    if not pretrained:
        # resume from a saved GANLearner checkpoint (load returns the learner)
        learn = learn.load(version)
    return learn
In [18]:
lr = 1e-4
In [19]:
learn = refresh_gan('pre')
gc.collect()
learn.fit(5, lr)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-19-bda027748263> in <module>
----> 1 learn = refresh_gan('pre')
      2 gc.collect()
      3 learn.fit(5, lr)

<ipython-input-17-9f609cb4e932> in refresh_gan(version, crit_thresh, loss_weights, bs, size)
      1 def refresh_gan(version, crit_thresh=0.65, loss_weights=(1.,50.), bs=48, size=128):
----> 2     data_gen = get_generator_data(bs, size)
      3     data_crit = get_critic_data([grayscale, color], bs=bs, size=size)
      4     generator = get_generator(data_gen)
      5     critic = get_critic(data_crit, metrics=None)

<ipython-input-11-89deec1480e9> in get_generator_data(bs, size)
      3     label_func = lambda x : path_color/x.name
      4     src = (ImageImageList
----> 5            .from_folder(path_gray)
      6            .split_by_rand_pct(0.1)
      7            .label_from_func(label_func))

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/vision/data.py in from_folder(cls, path, extensions, **kwargs)
    277         "Get the list of files in `path` that have an image suffix. `recurse` determines if we search subfolders."
    278         extensions = ifnone(extensions, image_extensions)
--> 279         return super().from_folder(path=path, extensions=extensions, **kwargs)
    280 
    281     @classmethod

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/data_block.py in from_folder(cls, path, extensions, recurse, exclude, include, processor, presort, **kwargs)
    127         `recurse` determines if we search subfolders."""
    128         path = Path(path)
--> 129         return cls(get_files(path, extensions, recurse=recurse, exclude=exclude, include=include, presort=presort), 
    130                    path=path, processor=processor, **kwargs)
    131 

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/data_block.py in get_files(path, extensions, recurse, exclude, include, presort, followlinks)
     33     if recurse:
     34         res = []
---> 35         for i,(p,d,f) in enumerate(os.walk(path, followlinks=followlinks)):
     36             # skip hidden dirs
     37             if include is not None and i==0:   d[:] = [o for o in d if o in include]

/opt/conda/envs/fastai/lib/python3.6/os.py in walk(top, topdown, onerror, followlinks)
    365 
    366             try:
--> 367                 is_dir = entry.is_dir()
    368             except OSError:
    369                 # If is_dir() raises an OSError, consider that the entry is not

KeyboardInterrupt: 
In [21]:
learn.save('gan-128-5')
In [20]:
learn = refresh_gan('gan-128-5')
gc.collect()
learn.fit(5, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.122432 3.403421 6.439204 0.596922 35:58
1 3.330484 3.362315 6.302394 0.569140 35:41
2 3.278370 3.380139 6.430364 0.578642 36:09
3 3.184355 3.368073 6.442218 0.579640 35:57
4 3.377575 3.371960 6.297174 0.569742 35:52
In [21]:
learn.save('gan-128-10')
In [21]:
learn = refresh_gan('gan-128-10', bs=bs*2//3, size=192)
gc.collect()
learn.fit(2, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.335786 3.416227 6.386610 0.559466 1:32:33
1 3.411482 3.402645 6.402471 0.566634 1:32:50
In [22]:
learn.save('gan-192-2')
In [20]:
learn = refresh_gan('gan-192-2', bs=bs*2//3, size=192)
gc.collect()
learn.fit(3, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.470173 3.335995 6.444449 0.562002 1:32:52
1 3.224054 3.396153 6.232610 0.548251 1:32:43
2 3.479856 3.266636 6.432885 0.564621 1:32:43
In [21]:
learn.save('gan-192-5')
In [24]:
learn = refresh_gan('gan-192-5', bs=bs*2//3, size=192)
gc.collect()
learn.fit(2, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.388874 3.368306 6.443987 0.572757 1:41:03
1 3.203342 3.405502 6.281001 0.577639 1:39:58
In [25]:
learn.save('gan-192-7')
In [26]:
learn = refresh_gan('gan-192-5', bs=bs//3, size=256)
gc.collect()
learn.fit(1, lr/2)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.190770 3.349433 6.251172 0.534610 2:59:06
In [27]:
learn.save('gan-256-1')
In [19]:
learn = refresh_gan('gan-256-1', bs=bs//3, size=256)
gc.collect()
learn.fit(1, lr/2)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.369833 3.397354 6.490441 0.567383 2:49:40
In [20]:
learn.save('gan-256-2')
In [22]:
learn = refresh_gan('gan-256-2', bs=bs//3, size=256)
gc.collect()
learn.fit(1, lr/2)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.526824 3.300664 6.608747 0.541469 2:44:57
In [23]:
learn.save('gan-256-3')
In [21]:
learn = refresh_gan('gan-256-3', bs=bs//3, size=256)
gc.collect()
learn.fit(1, lr/2)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.486530 3.381495 6.549844 0.544487 2:44:57
In [22]:
learn.save('gan-256-4')
In [23]:
learn = refresh_gan('gan-256-4', bs=bs//3, size=256)
gc.collect()
learn.fit(1, lr/2)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.481686 3.355738 6.621835 0.547295 2:45:04
In [24]:
learn.save('gan-256-5')
In [19]:
learn = refresh_gan('gan-256-5', bs=bs//5, size=320)
gc.collect()
learn.fit(1, lr/4)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.518986 3.300133 6.758544 0.563619 4:19:11
In [20]:
learn.save('gan-320-1')
In [26]:
learn = refresh_gan('gan-320-1', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr/4)
0.00% [0/1 00:00<00:00]
epoch train_loss valid_loss gen_loss disc_loss time

0.11% [12/10645 00:20<5:00:10 2.6060]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-26-0966792f871c> in <module>
      1 learn = refresh_gan('gan-320-1', bs=bs//6, size=320, p=0.8)
      2 gc.collect()
----> 3 learn.fit(1, lr/4)

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    198         else: self.opt.lr,self.opt.wd = lr,wd
    199         callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
--> 200         fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    201 
    202     def create_opt(self, lr:Floats, wd:Floats=0.)->None:

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
     99             for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
    100                 xb, yb = cb_handler.on_batch_begin(xb, yb)
--> 101                 loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
    102                 if cb_handler.on_batch_end(loss): break
    103 

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     28 
     29     if not loss_func: return to_detach(out), to_detach(yb[0])
---> 30     loss = loss_func(out, *yb)
     31 
     32     if opt is not None:

/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/vision/gan.py in forward(self, *args)
     46 
     47     def forward(self, *args):
---> 48         return self.generator(*args) if self.gen_mode else self.critic(*args)
     49 
     50     def switch(self, gen_mode:bool=None):

/opt/conda/envs/fastai/lib/python3.6/site-packages/fastai/vision/gan.py in critic(self, real_pred, input)
     66         "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.loss_funcD`."
     67         fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)
---> 68         fake_pred = self.gan_model.critic(fake)
     69         return self.loss_funcC(real_pred, fake_pred)
     70 

/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     98     def forward(self, input):
     99         for module in self:
--> 100             input = module(input)
    101         return input
    102 

/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     98     def forward(self, input):
     99         for module in self:
--> 100             input = module(input)
    101         return input
    102 

/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    522     def __call__(self, *input, **kwargs):
    523         for hook in self._forward_pre_hooks.values():
--> 524             result = hook(self, input)
    525             if result is not None:
    526                 if not isinstance(result, tuple):

/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/nn/utils/spectral_norm.py in __call__(self, module, inputs)
     97 
     98     def __call__(self, module, inputs):
---> 99         setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training))
    100 
    101     def _solve_v_and_rescale(self, weight_mat, u, target_sigma):

/opt/conda/envs/fastai/lib/python3.6/site-packages/torch/nn/utils/spectral_norm.py in compute_weight(self, module, do_power_iteration)
     83                     v = v.clone(memory_format=torch.contiguous_format)
     84 
---> 85         sigma = torch.dot(u, torch.mv(weight_mat, v))
     86         weight = weight / sigma
     87         return weight

KeyboardInterrupt: 
In [ ]:
learn.save('gan-320-2')
In [19]:
learn = refresh_gan('gan-320-2', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr/4)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.500950 3.086098 6.882501 0.557121 4:51:40
In [20]:
learn.save('gan-320-3')
In [19]:
learn = refresh_gan('gan-320-3', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr/4)
epoch train_loss valid_loss gen_loss disc_loss time
0 2.981765 3.009394 6.470855 0.594498 4:52:15
In [20]:
learn.save('gan-320-4')
In [19]:
learn = refresh_gan('gan-320-4', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr/4)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.164691 3.170581 6.320187 0.545949 4:45:15
In [20]:
learn.save('gan-320-5')
In [19]:
learn = refresh_gan('gan-320-5', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr/8)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.311751 3.173258 6.527965 0.586228 4:52:08
In [20]:
learn.save('gan-320-6')
In [20]:
learn = refresh_gan('gan-320-6', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr/8)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.213613 3.150085 6.288836 0.558234 4:48:19
In [21]:
learn.save('gan-320-7')
In [19]:
learn = refresh_gan('gan-320-7', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr/8)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.270710 3.218359 6.367918 0.570983 4:52:29
In [20]:
learn.save('gan-320-8')
In [19]:
learn = refresh_gan('gan-320-8', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr/8)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.112051 3.034698 6.280403 0.560253 4:47:58
In [20]:
learn.save('gan-320-9')
In [19]:
learn = refresh_gan('gan-320-9', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr/8)
epoch train_loss valid_loss gen_loss disc_loss time
0 2.966984 3.132009 6.156664 0.569404 4:52:51
In [ ]:
learn.save('gan-320-10')
In [19]:
learn = refresh_gan('gan-320-10', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.349297 3.289146 6.526690 0.541902 4:52:26
In [20]:
learn.save('gan-320-11')
In [19]:
learn = refresh_gan('gan-320-11', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.385822 3.279479 6.984418 0.571622 4:54:53
In [20]:
learn.save('gan-320-12')
In [21]:
learn = refresh_gan('gan-320-12', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.031612 3.294229 6.359600 0.569556 4:44:02
In [22]:
learn.save('gan-320-13')
In [19]:
learn = refresh_gan('gan-320-13', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.432154 3.127635 6.828905 0.558253 4:42:32
In [20]:
learn.save('gan-320-14')
In [19]:
learn = refresh_gan('gan-320-14', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.430166 3.374661 6.778205 0.539977 4:42:42
In [20]:
learn.save('gan-320-15')
In [19]:
learn = refresh_gan('gan-320-15', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.535818 3.270476 6.796846 0.551929 4:47:44
In [20]:
learn.save('gan-320-16')
In [19]:
learn = refresh_gan('gan-320-16', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.120983 2.853583 6.244678 0.529247 4:51:41
In [20]:
learn.save('gan-320-17') # good results here
In [19]:
learn = refresh_gan('gan-320-17', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.170357 3.260084 6.195094 0.544852 4:51:50
In [20]:
learn.save('gan-320-18')
In [19]:
learn = refresh_gan('gan-320-18', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.251461 3.273478 6.225502 0.544857 4:51:03
In [20]:
learn.save('gan-320-19')
In [19]:
learn = refresh_gan('gan-320-19', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.135075 3.205379 6.286476 0.549460 3:24:58
In [20]:
learn.save('gan-320-20') #good results
In [19]:
learn = refresh_gan('gan-320-20', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.015370 3.284572 6.970301 0.595093 4:51:47
In [20]:
learn.save('gan-320-21')
In [19]:
learn = refresh_gan('gan-320-21', bs=bs//6, size=320, p=0.4)
gc.collect()
learn.fit(3, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.223896 3.060349 6.083280 0.527072 1:42:28
1 3.524580 3.251389 7.184547 0.577645 1:42:10
2 2.992206 2.861572 6.631731 0.589967 1:42:29
In [20]:
learn.save('gan-320-22')
In [19]:
learn = refresh_gan('gan-320-22', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
epoch train_loss valid_loss gen_loss disc_loss time
0 3.051067 3.132140 5.976385 0.551725 4:49:21
In [20]:
learn.save('gan-320-23')
In [ ]:
learn = refresh_gan('gan-320-23', bs=bs//6, size=320, p=0.8)
gc.collect()
learn.fit(1, lr)
In [20]:
learn.save('gan-320-24') # pretty decent
In [25]:
learn.show_results(rows=6)
In [ ]:
learn = refresh_gan('gan-320-24', bs=6, size=320, p=0.2)
In [112]:
i = 20
open_image(path_color.ls()[i])
Out[112]:
In [113]:
open_image(path_gray.ls()[i])
Out[113]:
In [114]:
gray_img = open_image(path_gray.ls()[i])
learn.predict(gray_img)[0]
Out[114]:

Upload to Dropbox

In [84]:
learn = refresh_gan('gan-320-24', bs=bs//6, size=320, p=0.8)
save_path = Path('/storage/capstone/colorizer')
fp = learn.save(save_path, return_path=True, with_opt=True) # pretty decent
fp = str(fp)
In [25]:
# export model
learn.export('colorizer.pkl')
fp = learn.path/'colorizer.pkl'
In [86]:
# modified solution found at: https://stackoverflow.com/questions/37397966/dropbox-api-v2-upload-large-files-using-python
def upload(access_token, file_path, target_path, timeout=900, chunk_size=4*1024*1024):
    """Upload `file_path` to Dropbox at `target_path`.

    Files no larger than `chunk_size` go up in a single `files_upload` call;
    larger files use an upload session, sending `chunk_size` parts with a
    tqdm progress bar.
    """
    dbx = dropbox.Dropbox(access_token, timeout=timeout)
    with open(file_path, "rb") as f:
        file_size = os.path.getsize(file_path)
        if file_size <= chunk_size:
            print(dbx.files_upload(f.read(), target_path))
            return
        with tqdm(total=file_size, desc="Uploaded") as pbar:
            chunk = f.read(chunk_size)
            start_result = dbx.files_upload_session_start(chunk)
            # fix: advance the bar by the bytes actually read, not the nominal
            # chunk_size — the old code over-counted on the final partial chunk
            pbar.update(len(chunk))
            cursor = dropbox.files.UploadSessionCursor(session_id=start_result.session_id, offset=f.tell())
            commit = dropbox.files.CommitInfo(path=target_path)
            while f.tell() < file_size:
                remaining = file_size - f.tell()
                chunk = f.read(chunk_size)
                if remaining <= chunk_size:
                    # last part: finish the session and commit the file
                    print(dbx.files_upload_session_finish(chunk, cursor, commit))
                else:
                    dbx.files_upload_session_append(
                        chunk,
                        cursor.session_id,
                        cursor.offset,
                    )
                    cursor.offset = f.tell()
                pbar.update(len(chunk))
In [87]:
# SECURITY: Dropbox access token hardcoded and committed with the notebook —
# it must be treated as compromised: revoke/rotate it and load the replacement
# from an environment variable or getpass() instead of embedding it here.
key = "Wx9jfJEGYHsAAAAAAAAAU_4bjtXXeDSNPqylOcF_k6q4VB-lzxJ6vv_WPwxDVIQ2"
db_path = "/Apps/ImgBuff/colorizer.pth"
upload(key, fp, db_path)
Uploaded: 507510784it [01:36, 4987357.23it/s]                               
FileMetadata(name='colorizer.pth', id='id:V6kESqQ-glAAAAAAAAAAIA', client_modified=datetime.datetime(2020, 6, 22, 3, 8, 5), server_modified=datetime.datetime(2020, 6, 22, 3, 8, 5), rev='015a8a38d31783900000001d093aa70', size=506122862, path_lower='/apps/imgbuff/colorizer.pth', path_display='/Apps/ImgBuff/colorizer.pth', parent_shared_folder_id=None, media_info=None, symlink_info=None, sharing_info=None, is_downloadable=True, export_info=None, property_groups=None, has_explicit_shared_members=None, content_hash='031cb54de6ece45e7842b7f79373f70ec4f0e230c26644b5ed0d2c819a55c2fd', file_lock_info=None)

In [42]:
# NOTE(review): rebinding `color` here shadows the folder-name constant set in
# the setup cell ('color'); any later call to refresh_gan/get_critic_data in
# this kernel session would receive a file path instead of the folder name.
# Prefer a distinct name such as `color_fp`.
gray = '/storage/capstone/grayscale/000000354235.jpg'
color = '/storage/capstone/color/000000354235.jpg'
In [43]:
# SECURITY: hardcoded Dropbox access token committed with the notebook —
# revoke/rotate it and read the replacement from the environment instead.
key = "Wx9jfJEGYHsAAAAAAAAAU_4bjtXXeDSNPqylOcF_k6q4VB-lzxJ6vv_WPwxDVIQ2"
db_path = "/Apps/ImgBuff/grayscale/000000354235.jpg"
upload(key, gray, db_path)
db_path = "/Apps/ImgBuff/color/000000354235.jpg"
upload(key, color, db_path)
FileMetadata(name='000000354235.jpg', id='id:V6kESqQ-glAAAAAAAAAAGg', client_modified=datetime.datetime(2020, 6, 22, 1, 16, 50), server_modified=datetime.datetime(2020, 6, 22, 1, 16, 50), rev='015a8a1ff57fc0d00000001d093aa70', size=183867, path_lower='/apps/imgbuff/grayscale/000000354235.jpg', path_display='/Apps/ImgBuff/grayscale/000000354235.jpg', parent_shared_folder_id=None, media_info=None, symlink_info=None, sharing_info=None, is_downloadable=True, export_info=None, property_groups=None, has_explicit_shared_members=None, content_hash='1d481cbbb22fba748a845f238539cd30ed45671287d74faf4776e0516dbb45f3', file_lock_info=None)
FileMetadata(name='000000354235.jpg', id='id:V6kESqQ-glAAAAAAAAAAGw', client_modified=datetime.datetime(2020, 6, 22, 1, 16, 52), server_modified=datetime.datetime(2020, 6, 22, 1, 16, 52), rev='015a8a1ff746fbf00000001d093aa70', size=203763, path_lower='/apps/imgbuff/color/000000354235.jpg', path_display='/Apps/ImgBuff/color/000000354235.jpg', parent_shared_folder_id=None, media_info=None, symlink_info=None, sharing_info=None, is_downloadable=True, export_info=None, property_groups=None, has_explicit_shared_members=None, content_hash='7db0c7d23ed3fb997b37424233f21421aecea55fae009f4ce29df27d968176f0', file_lock_info=None)
In [ ]: